#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from __future__ import annotations

import argparse
import json
import sys
import warnings
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Put this script's own directory at the front of sys.path so the sibling
# module ``analyze_env_policy_effects`` imported below resolves regardless of
# the current working directory the script is launched from.
_HERE = Path(__file__).resolve().parent
if str(_HERE) not in sys.path:
    sys.path.insert(0, str(_HERE))

from analyze_env_policy_effects import (  # type: ignore
    _pilot_year_series,
    _treat_series,
    prepare_panel,
    twfe_cluster,
)


def _two_way_demean(
    values: np.ndarray,
    city_codes: np.ndarray,
    years: np.ndarray,
    max_iter: int = 50,
    *,
    tol: float = 1e-10,
) -> np.ndarray:
    """Residualize ``values`` on city and year fixed effects by alternating projections.

    Each sweep subtracts the city-group mean and then the year-group mean.
    Sweeps repeat until the largest element-wise change between consecutive
    sweeps drops below ``tol`` or ``max_iter`` sweeps have run. A balanced
    panel is exact after one sweep; unbalanced panels converge more slowly.

    Parameters
    ----------
    values:
        Observation-level values to residualize. NaNs are skipped when group
        means are computed and remain NaN in the output.
    city_codes, years:
        Group labels aligned with ``values``; both are compared as strings.
    max_iter:
        Maximum number of (city, year) demeaning sweeps.
    tol:
        Absolute convergence threshold on the maximum per-sweep change.

    Returns
    -------
    np.ndarray
        A new float array with the same length as ``values``; the input is not
        modified.

    Warns
    -----
    RuntimeWarning
        If the projections have not converged after ``max_iter`` sweeps, so
        the residuals are only approximate.
    """
    v = np.asarray(values, dtype=float).copy()
    if v.size == 0:
        return v

    # Factorize the labels once rather than re-hashing string keys every sweep.
    city_idx, city_levels = pd.factorize(np.asarray(city_codes).astype(str))
    year_idx, year_levels = pd.factorize(np.asarray(years).astype(str))
    n_city = len(city_levels)
    n_year = len(year_levels)

    def _group_mean(arr: np.ndarray, idx: np.ndarray, n_groups: int) -> np.ndarray:
        """Return the NaN-skipping within-group mean of ``arr``, broadcast back to observations."""
        ok = ~np.isnan(arr)
        sums = np.bincount(idx[ok], weights=arr[ok], minlength=n_groups)
        counts = np.bincount(idx[ok], minlength=n_groups)
        means = np.full(n_groups, np.nan)
        # Groups with no finite values keep a NaN mean; their rows are NaN already.
        np.divide(sums, counts, out=means, where=counts > 0)
        return means[idx]

    for _ in range(max_iter):
        v_prev = v
        v = v - _group_mean(v, city_idx, n_city)
        v = v - _group_mean(v, year_idx, n_year)

        change = np.abs(v - v_prev)
        change = change[~np.isnan(change)]
        # An all-NaN input has nothing to update, so it counts as converged.
        if change.size == 0 or change.max() < tol:
            break
    else:
        if max_iter > 0:
            warnings.warn(
                f"_two_way_demean did not converge within {max_iter} iterations "
                f"(tol={tol:g}); fixed-effect residuals may be inaccurate.",
                RuntimeWarning,
                stacklevel=2,
            )
    return v


def _cluster_se_single_regressor(x: np.ndarray, y: np.ndarray, *, beta: float, clusters: np.ndarray) -> float:
    """Return the cluster-robust standard error of a no-intercept single-regressor slope.

    Uses the sandwich estimator

        Var(beta) = c * sum_g (sum_{i in g} x_i * u_i)^2 / (sum_i x_i^2)^2

    with residuals ``u = y - beta * x`` and the finite-sample factor
    ``c = G/(G-1) * (N-1)/(N-K)`` (K = 1), where G is the number of clusters.

    Parameters
    ----------
    x, y:
        Regressor and outcome, typically already residualized on fixed effects.
    beta:
        The slope estimate whose standard error is wanted.
    clusters:
        Cluster labels aligned with ``x``; compared as strings.

    Returns
    -------
    float
        The standard error. NaN when it is not identified: ``beta`` is not
        finite, ``sum(x**2)`` is not positive, or any score ``x * u`` is not
        finite.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    # A pandas groupby-sum treats NaN as 0, so a NaN beta or NaN residuals used
    # to produce a misleading SE (0.0 in the NaN-beta case) instead of NaN.
    if not np.isfinite(beta):
        return float("nan")

    xx = float(np.dot(x, x))
    if not xx > 0:  # also rejects NaN
        return float("nan")

    score = x * (y - beta * x)
    if not np.all(np.isfinite(score)):
        return float("nan")

    cluster_idx, cluster_levels = pd.factorize(np.asarray(clusters).astype(str))
    g = len(cluster_levels)
    cluster_scores = np.bincount(cluster_idx, weights=score, minlength=g)
    meat = float(np.dot(cluster_scores, cluster_scores))

    n = int(x.size)
    k = 1
    # Finite-sample correction commonly used for clustered SE.
    scale = (g / (g - 1)) * ((n - 1) / (n - k)) if g > 1 and n > k else 1.0
    var = scale * meat / (xx * xx)
    return float(np.sqrt(var)) if var >= 0 else float("nan")


@dataclass(frozen=True)
class StackedDidResult:
    """Summary of one stacked difference-in-differences estimate.

    Built by ``estimate_stacked_did``. NOTE: although the dataclass is frozen,
    ``cohorts`` is a list, so calling ``hash()`` on an instance raises
    ``TypeError``.
    """

    # Pooled treated-x-post coefficient after stack-specific city/year FE (NaN if not identified).
    coef: float
    # Standard error clustered on the original ``city_code6``.
    se: float
    # Number of stacked rows used (one row per city-year per stack it appears in).
    nobs: int
    # Sorted distinct pilot years of treated cities; one stack per cohort.
    cohorts: list[int]
    # (lo, hi) year offsets relative to each cohort's pilot year, both inclusive.
    window: tuple[int, int]


def estimate_stacked_did(
    df: pd.DataFrame,
    *,
    policy: str,
    outcome: str,
    window: tuple[int, int],
) -> StackedDidResult:
    """Estimate a stacked difference-in-differences effect of ``policy`` on ``outcome``.

    For every adoption cohort ``g`` (a distinct pilot year among treated
    cities), a sub-panel is built. It contains the cities piloting in ``g``
    plus all never-treated cities (``treat == 0``), restricted to years
    ``g + window[0]`` through ``g + window[1]`` inclusive. The sub-panels are
    stacked. The outcome and the treated-x-post indicator are residualized on
    stack-specific city and year fixed effects. The slope is then estimated,
    with standard errors clustered on the original ``city_code6`` (a control
    city appears in several stacks).

    Parameters
    ----------
    df:
        City-year panel with ``city_code6`` and ``year`` columns.
    policy:
        Policy name understood by ``_treat_series`` / ``_pilot_year_series``.
    outcome:
        Outcome column; values that cannot be parsed as numbers are dropped.
    window:
        ``(lo, hi)`` year offsets around each cohort's pilot year.

    Returns
    -------
    StackedDidResult
        ``coef`` and ``se`` are NaN (and ``nobs`` is 0 when no rows survive)
        if there is no treated cohort, no usable observation, or no identifying
        variation.
    """
    nan = float("nan")
    treat = _treat_series(df, policy)
    py = _pilot_year_series(df, policy)
    cohorts = sorted(int(c) for c in py[treat == 1].dropna().unique().tolist())

    # Controls are the same in every stack: only never-treated cities, so a
    # later-adopting cohort never serves as a control for an earlier one.
    never_cities = set(df.loc[treat == 0, "city_code6"].unique().tolist())

    stacks: list[pd.DataFrame] = []
    for g in cohorts:
        treated_cities = set(df.loc[py == g, "city_code6"].unique().tolist())
        keep_cities = treated_cities | never_cities

        lo, hi = g + window[0], g + window[1]
        d = df[df["city_code6"].isin(keep_cities) & df["year"].between(lo, hi)].copy()
        d["stack_cohort"] = g
        # Suffixing with the cohort makes the city and year fixed effects stack-specific.
        d["stack_city"] = d["city_code6"].astype(str) + f"_{g}"
        d["stack_year"] = d["year"].astype(int).astype(str) + f"_{g}"
        d["stack_treat"] = d["city_code6"].isin(treated_cities).astype(int)
        d["stack_post"] = (d["year"] >= g).astype(int)
        d["stack_did"] = (d["stack_treat"] * d["stack_post"]).astype(int)
        stacks.append(d)

    if not stacks:
        # No treated cohort at all; pd.concat([]) would raise.
        return StackedDidResult(coef=nan, se=nan, nobs=0, cohorts=cohorts, window=window)

    stacked = pd.concat(stacks, ignore_index=True)
    # Coerce before dropping, so unparseable outcome values leave the sample
    # instead of propagating NaN through the demeaning into the coefficient.
    stacked[outcome] = pd.to_numeric(stacked[outcome], errors="coerce")
    stacked = stacked.dropna(subset=[outcome])
    if stacked.empty:
        return StackedDidResult(coef=nan, se=nan, nobs=0, cohorts=cohorts, window=window)

    y = stacked[outcome].to_numpy(dtype=float)
    x = stacked["stack_did"].to_numpy(dtype=float)
    city_fe = stacked["stack_city"].to_numpy()
    year_fe = stacked["stack_year"].to_numpy()

    # Residualize w.r.t. stack-specific city and year FE (Frisch-Waugh-Lovell),
    # then run a single-regressor regression clustered at the original city level.
    y_tilde = _two_way_demean(y, city_fe, year_fe)
    x_tilde = _two_way_demean(x, city_fe, year_fe)
    denom = float(np.dot(x_tilde, x_tilde))
    coef = float(np.dot(x_tilde, y_tilde) / denom) if denom > 0 else nan
    # A NaN coefficient has no meaningful standard error.
    se = (
        _cluster_se_single_regressor(x_tilde, y_tilde, beta=coef, clusters=stacked["city_code6"].to_numpy())
        if np.isfinite(coef)
        else nan
    )
    return StackedDidResult(coef=coef, se=se, nobs=int(len(stacked)), cohorts=cohorts, window=window)


def placebo_shift_did_pre_sample(
    df: pd.DataFrame,
    *,
    policy: str,
    outcome: str,
    shift_years: int,
) -> tuple[float, float, int]:
    """Run a fake-timing DID on the pre-treatment sample only.

    Never-treated cities keep every observation. Treated cities keep only the
    years strictly before their pilot year. Each treated city is then given a
    placebo pilot year ``pilot_year - shift_years``, and a ``placebo_did``
    dummy (treated city in or after its placebo year) is estimated with city
    and year fixed effects by ``twfe_cluster``.

    Returns
    -------
    tuple[float, float, int]
        Whatever ``twfe_cluster`` reports for ``placebo_did`` (annotated here
        as coefficient, standard error, and observation count).
    """
    treat = _treat_series(df, policy)
    pilot = _pilot_year_series(df, policy)

    in_pre_sample = (treat == 0) | (df["year"] < pilot)
    sample = df[in_pre_sample].copy()

    rows = sample.index
    treated_flag = treat.loc[rows].to_numpy()
    # Untreated rows get a far-future placebo year so they are never "post".
    fake_start = np.where(treated_flag == 1, pilot.loc[rows].to_numpy() - shift_years, 9999)
    fake_post = (sample["year"].to_numpy() >= fake_start).astype(int)
    sample["placebo_did"] = (treated_flag * fake_post).astype(int)

    spec = f"{outcome} ~ placebo_did + C(city_code6) + C(year)"
    estimable = sample.dropna(subset=[outcome])
    return twfe_cluster(estimable, formula=spec, coef="placebo_did")


def permute_pilot_years(
    df: pd.DataFrame,
    *,
    policy: str,
    outcome: str,
    n_perm: int,
    seed: int,
) -> tuple[float, float, float, np.ndarray]:
    """Randomization-inference test that reshuffles pilot years across treated cities.

    The set of treated cities stays fixed. Each draw shuffles their pilot years
    among them, rebuilds the treated-x-post regressor, and re-estimates the
    two-way fixed-effects (city + year) slope via alternating-projection
    demeaning.

    Parameters
    ----------
    df:
        City-year panel with ``city_code6`` and ``year`` columns.
    policy:
        Policy name understood by ``_treat_series`` / ``_pilot_year_series``.
    outcome:
        Outcome column; values that cannot be parsed as numbers are dropped.
    n_perm:
        Number of permutation draws.
    seed:
        Seed for ``numpy.random.default_rng``.

    Returns
    -------
    tuple
        ``(b_hat, p_emp, mean_perm, coefs)``:

        - ``b_hat`` is the observed coefficient (NaN if not identified).
        - ``p_emp`` is the two-sided empirical p-value
          ``(#{|b_perm| >= |b_hat|} + 1) / (draws + 1)``. It is NaN when
          ``b_hat`` or every draw is NaN.
        - ``mean_perm`` is the mean permuted coefficient (NaN if none).
        - ``coefs`` holds the finite permuted coefficients.
    """
    rng = np.random.default_rng(seed)

    treat = _treat_series(df, policy).astype(int)
    py = _pilot_year_series(df, policy)

    # Coerce before dropping, so unparseable outcome values leave the sample
    # instead of making every demeaned value NaN.
    y_all = pd.to_numeric(df[outcome], errors="coerce")
    has_y = y_all.notna()
    d = df.loc[has_y]
    y = y_all[has_y].to_numpy(dtype=float)

    treat_d = treat.loc[d.index].to_numpy()
    py_d = pd.to_numeric(py.loc[d.index], errors="coerce").to_numpy(dtype=float)

    city = d["city_code6"].astype(str).to_numpy()
    year_int = d["year"].astype(int).to_numpy()
    year_group = d["year"].astype(int).astype(str).to_numpy()

    y_tilde = _two_way_demean(y, city, year_group)
    is_treated = treat_d == 1

    # Observed DID term under the true pilot years. A NaN pilot year compares
    # False, so such observations are never "post".
    x_true = (is_treated & (year_int >= py_d)).astype(float)
    x_true_tilde = _two_way_demean(x_true, city, year_group)
    denom_true = float(np.dot(x_true_tilde, x_true_tilde))
    b_hat = float(np.dot(x_true_tilde, y_tilde) / denom_true) if denom_true > 0 else float("nan")

    # One pilot year per treated city, in first-appearance order. It comes from
    # the same _pilot_year_series used for the observed term. Codes are cast to
    # str to match ``city``: keying by raw codes while looking up str codes
    # never matches numeric city codes.
    treated_rows = treat == 1
    pilot_by_city = pd.Series(
        pd.to_numeric(py[treated_rows], errors="coerce").to_numpy(dtype=float),
        index=df.loc[treated_rows, "city_code6"].astype(str).to_numpy(),
    )
    pilot_by_city = pilot_by_city[~pilot_by_city.index.duplicated(keep="first")].dropna()
    pilot_values = pilot_by_city.to_numpy().astype(int)

    # Position of each observation's city in ``pilot_values``. -1 (not a
    # treated city with a known pilot year) indexes the trailing 9999 sentinel
    # appended below, i.e. "never post".
    city_pos = pd.Index(pilot_by_city.index).get_indexer(city)

    coefs = np.full(n_perm, np.nan, dtype=float)
    for i in range(n_perm):
        permuted = pilot_values.copy()
        rng.shuffle(permuted)
        py_perm = np.append(permuted, 9999)[city_pos]
        x = (is_treated & (year_int >= py_perm)).astype(float)
        x_tilde = _two_way_demean(x, city, year_group)
        denom = float(np.dot(x_tilde, x_tilde))
        if denom > 0:
            coefs[i] = float(np.dot(x_tilde, y_tilde) / denom)

    coefs = coefs[np.isfinite(coefs)]
    if coefs.size == 0 or not np.isfinite(b_hat):
        # No reference distribution or no observed statistic: the p-value is undefined.
        p_emp = float("nan")
    else:
        # Two-sided empirical p-value with the +1 correction.
        p_emp = float((np.sum(np.abs(coefs) >= abs(b_hat)) + 1.0) / (coefs.size + 1.0))
    mean_perm = float(coefs.mean()) if coefs.size else float("nan")
    return b_hat, p_emp, mean_perm, coefs


def _json_safe(obj):
    """Return a copy of *obj* that ``json.dumps(..., allow_nan=False)`` accepts.

    Non-finite floats become ``None``. NumPy scalars become Python scalars.
    Tuples and arrays become lists, and dict keys become strings.
    """
    if isinstance(obj, dict):
        return {str(k): _json_safe(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple, np.ndarray)):
        return [_json_safe(v) for v in obj]
    if isinstance(obj, (bool, np.bool_)):
        return bool(obj)
    if isinstance(obj, (int, np.integer)):
        return int(obj)
    if isinstance(obj, (float, np.floating)):
        value = float(obj)
        return value if np.isfinite(value) else None
    return obj


def main() -> int:
    """Run the appendix robustness checks, save the permutation figure, and write a JSON summary.

    The checks are a stacked DID, placebo-shift DIDs on the pre-treatment
    sample, and a pilot-year permutation test. Returns the process exit code (0).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--panel", default="media_project/out/env_policy_city_year_with_wechat.csv")
    ap.add_argument("--policy", default="lowcarbon")
    ap.add_argument("--outcome", default="good_day_share")
    ap.add_argument("--days-threshold", type=int, default=330)
    ap.add_argument("--perm-n", type=int, default=300)
    ap.add_argument("--perm-seed", type=int, default=7)
    ap.add_argument("--fig-out", default="media_project/paper/figures/Figure_A3.png")
    ap.add_argument("--json-out", default="media_project/reports/appendix_robustness_extra.json")
    args = ap.parse_args()

    df = prepare_panel(args.panel)
    df = df.dropna(subset=[args.outcome, "days"]).copy()
    df["days"] = pd.to_numeric(df["days"], errors="coerce")
    df = df.dropna(subset=["days"]).copy()
    # Keep only city-years with enough monitored days.
    df = df[df["days"] >= args.days_threshold].copy()

    # Stacked DID.
    stacked = estimate_stacked_did(df, policy=args.policy, outcome=args.outcome, window=(-8, 6))

    # Placebo shifts; unpacking checks the (coef, se, n) shape of each result.
    placebo = {}
    for s in (2, 3, 4):
        b, se, n = placebo_shift_did_pre_sample(df, policy=args.policy, outcome=args.outcome, shift_years=s)
        placebo[s] = (b, se, n)

    # Permutation test.
    b_hat, p_emp, mean_perm, coefs = permute_pilot_years(
        df, policy=args.policy, outcome=args.outcome, n_perm=args.perm_n, seed=args.perm_seed
    )

    fig_path = Path(args.fig_out)
    fig_path.parent.mkdir(parents=True, exist_ok=True)

    fig, ax = plt.subplots(figsize=(7.2, 4.2))
    try:
        ax.hist(coefs, bins=30, alpha=0.85, color="#4C78A8", edgecolor="white")
        ax.axvline(b_hat, color="#E45756", linewidth=2.0, label=f"Observed coef = {b_hat:.3f}")
        ax.set_title(f"Permutation test: {args.policy} pilot-year reassignment ({len(coefs)} draws)")
        ax.set_xlabel("Placebo DID coefficient (TWFE city+year FE)")
        ax.set_ylabel("Count")
        ax.legend(frameon=False)
        fig.tight_layout()
        fig.savefig(fig_path, dpi=300)
    finally:
        # Release the figure even if saving fails.
        plt.close(fig)

    out = {
        "stacked_did": {"coef": stacked.coef, "se": stacked.se, "n": stacked.nobs, "cohorts": stacked.cohorts, "window": stacked.window},
        "placebo_shift_pre": {str(k): {"coef": v[0], "se": v[1], "n": v[2]} for k, v in placebo.items()},
        "permutation": {"observed_coef": b_hat, "p_empirical_two_sided": p_emp, "mean_perm": mean_perm, "n_perm": int(len(coefs))},
        "figure": str(fig_path),
    }

    out_json = Path(args.json_out)
    out_json.parent.mkdir(parents=True, exist_ok=True)
    # Emit strict JSON: NaN/inf become null, and NumPy scalars become native types.
    payload = json.dumps(_json_safe(out), ensure_ascii=False, indent=2, allow_nan=False)
    out_json.write_text(payload + "\n", encoding="utf-8")

    print("STACKED_DID", out["stacked_did"])
    print("PLACEBO_SHIFT_PRE", out["placebo_shift_pre"])
    print("PERMUTATION", out["permutation"])
    print("FIGURE", out["figure"])
    return 0


if __name__ == "__main__":
    # Use main()'s integer return value as the process exit status.
    raise SystemExit(main())
